Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various masked operations #2428

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

mazimkhan
Copy link

Introduces:

  • MaskedOrOrZero(m, a, b): returns a[i] || b[i] or zero if m[i] is false.
  • TwoTablesLookupLanesOr(d, m, a, b, unspecified): returns the result of TwoTablesLookupLanes(V a, V b, unspecified) where m[i] is true, and a[i] where m[i] is false.
  • TwoTablesLookupLanesOrZero(d, m, a, b, unspecified): returns the result of TwoTablesLookupLanes(V a, V b, unspecified) where m[i] is true, and zero where m[i] is false.
  • MaskedReduceSum(d, m, v): returns the sum of all lanes where m[i] is true.
  • MaskedReduceMin(d, m, v): returns the minimum of all lanes where m[i] is true.
  • MaskedReduceMax(d, m, v): returns the maximum of all lanes where m[i] is true.
  • IfNegativeThenNegOrUndefIfZero(mask, v): returns mask[i] < 0 ? (-v[i]) : ((mask[i] > 0) ? v[i] : impl_defined_val), where impl_defined_val is an implementation-defined value that is equal to either 0 or v[i]. SVE included only.

Testing is performed for all new operations.

Copy link

google-cla bot commented Jan 6, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@@ -1050,6 +1050,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about a different naming convention here which might be a bit more natural?
There is also a MaskedLoad which returns 0 as the default, as opposed to MaskedLoadOr, which has the explicit default value. If we apply that here, we can just call it MaskedOr(m, a b), what do you think?

@@ -1050,6 +1050,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOrOrZero**(M m, V a, V b)</code>: returns `a[i] || b[i]`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we mean a[i] | b[i]?

@@ -2237,6 +2240,22 @@ The following `ReverseN` must not be called if `Lanes(D()) < N`:
must be in the range `[0, 2 * Lanes(d))` but need not be unique. The index
type `TI` must be an integer of the same size as `TFromD<D>`.
* <code>V **TableLookupLanesOr**(M m, V a, V b, unspecified)</code> returns the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we don't yet have an optimized version of these op, and it's just a convenience wrapper over IfThenElse. Would it be an option to move this into a utility function within your codebase? It's not clear whether this provides enough value to be a documented op that all readers must know.

IfThenElseZero(m, v)))` etc. The result is implementation-defined when all mask
elements are false.
* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This looks useful.
Please add a TODO that we should also implement this for RVV.

#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

template <class D, class M>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a TODO here that we can remove the SumOfLanesM in favor of using MaskedReduceSum directly. This entails adding the D arg to HWY_SVE_REDUCE_ADD as done in HWY_SVE_FIRSTN.

@@ -4755,6 +4804,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
return IfThenElse(IsNegative(v), yes, no);
}
// ------------------------------ IfNegativeThenNegOrUndefIfZero
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This op is undocumented, do we intend to add it? If so, let's add documentation and test.

@@ -219,6 +219,15 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: we have the naming convention P for predicate, for example in HWY_SVE_RETV_ARGPVV. I'm fine with either P or M, but let's please be consistent, feel free to pick one.
This might actually replace the existing HWY_SVE_RETV_ARGPV.

}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
return ReduceMin(d, IfThenElse(m, v, MaxOfLanes(d, v)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unnecessarily expensive, how about we replace MaxOfLanes with Set(d, hwy::HighestValue)?

}
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
return ReduceMax(d, IfThenElseZero(m, v));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can get into trouble for signed values. If all values are negative, the presence of mask=false elements changes the result. Can similarly use hwy::LowestValue here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants